from __future__ import print_function
import torch
import argparse

from utils import load_data
from utils import anchor_random_crop
from utils import save_model
from utils import scheduler_step
from network_and_loss import FeatNet
from network_and_loss import PWConLoss

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.float


def _str2bool(value):
    """Convert a command-line flag value such as ``True``/``false``/``1``/``no`` to a bool.

    Without this, argparse keeps the raw string, so ``--random_flip False``
    would produce the non-empty (and therefore truthy) string ``"False"``.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


parser = argparse.ArgumentParser(description='FeatNet training with PWConLoss')

# Dataset and data-augmentation options.
parser.add_argument('--task', type=str, default="COFW", help='Task dataset')
parser.add_argument('--random_scale', type=_str2bool, default=True, help='Whether to apply random scale')
parser.add_argument('--random_flip', type=_str2bool, default=True, help='Whether to apply random flip')
parser.add_argument('--random_rotation', type=_str2bool, default=False, help='Whether to apply random rotation')

# Optimisation hyperparameters.
parser.add_argument('--batch_size', type=int, default=256, help='Batch size')
parser.add_argument('--num_epochs', type=int, default=2000, help='Maximum number of epochs')
parser.add_argument('--learning_rate', type=float, default=1E-2, help='Model learning rate')
parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='Learning rate decay rate')
parser.add_argument('--lr_decay_interval', type=int, default=1500, help='Learning rate decay interval')

args = parser.parse_args()

def main(save_interval=500):
    """Train the five region-specific FeatNets with PWConLoss.

    Hyperparameters come from the module-level ``args``. Each epoch prints the
    per-region average training loss returned by ``train``.

    A checkpoint is written with ``save_model`` every ``save_interval`` epochs
    AND after the final epoch. Without the final save, runs whose
    ``args.num_epochs`` is not a multiple of the interval would lose their
    trailing epochs.

    Args:
        save_interval: Positive number of epochs between checkpoints
            (default 500, the value that was previously hard-coded).
    """
    torch.cuda.empty_cache()
    train_loader, _ = load_data(args.task, args.batch_size, args.random_scale,
                                args.random_flip, args.random_rotation)

    # Construction order is kept (leye, reye, mouth, nose, jaw) so weight
    # initialisation consumes the RNG in the same sequence as before.
    featnet_leye = FeatNet().to(device)
    featnet_reye = FeatNet().to(device)
    featnet_mouth = FeatNet().to(device)
    featnet_nose = FeatNet().to(device)
    featnet_jaw = FeatNet().to(device)
    featnets = (featnet_leye, featnet_reye, featnet_mouth, featnet_nose, featnet_jaw)

    criterion = PWConLoss().to(device)
    # One parameter group per network, all starting at the same learning rate.
    optimizer = torch.optim.Adam(
        [{'params': net.parameters(), 'lr': args.learning_rate} for net in featnets])

    for epoch in range(args.num_epochs):
        # The learning-rate schedule is not touched before the very first epoch.
        if epoch != 0:
            scheduler_step(optimizer, epoch, args.lr_decay_interval, args.lr_decay_rate)

        loss_leye, loss_reye, loss_mouth, loss_nose, loss_jaw = train(
            featnet_leye, featnet_reye, featnet_mouth, featnet_nose, featnet_jaw,
            train_loader, criterion, optimizer)

        print("\nEpoch: {}/{}.. ".format(epoch+1, args.num_epochs).ljust(14),
              "PWConLoss_leye: {:.3f}.. ".format(loss_leye).ljust(14),
              "PWConLoss_reye: {:.3f}.. ".format(loss_reye).ljust(14),
              "PWConLoss_mouth: {:.3f}.. ".format(loss_mouth).ljust(14),
              "PWConLoss_nose: {:.3f}.. ".format(loss_nose).ljust(14),
              "PWConLoss_jaw: {:.3f}.. ".format(loss_jaw).ljust(14))

        completed_epochs = epoch + 1
        if completed_epochs % save_interval == 0 or completed_epochs == args.num_epochs:
            save_model("FeatNet", featnet_leye, featnet_reye, featnet_mouth,
                       featnet_nose, featnet_jaw, optimizer, completed_epochs)
        
        
        
def _train_region(featnet, images, region_coords, criterion, batch_size):
    """Accumulate PWConLoss gradients for one region network on one batch.

    ``anchor_random_crop`` builds an anchor view and a random view of ``images``
    around ``region_coords``, plus the landmark/distance/relationship targets
    that ``criterion`` consumes. Both views go through ``featnet`` in a single
    forward pass and are split back into two halves of ``batch_size`` rows.
    The halves are stacked along a new dim 1 before the loss is computed.

    ``backward()`` is called here, so gradients are accumulated into
    ``featnet``'s ``.grad`` buffers. The caller is responsible for zeroing
    gradients and stepping the optimizer.

    Returns:
        The batch loss as a Python float.
    """
    anchor_view, random_view, landmark_select, \
        a_l_distance, a_l_relationship, r_l_distance, r_l_relationship \
            = anchor_random_crop(images, region_coords)
    # NOTE(review): assumes anchor_view and random_view each hold batch_size
    # samples, as required by the [batch_size, batch_size] split below.
    anchor_random = torch.cat((anchor_view, random_view), 0)
    z = featnet(anchor_random)
    z_anchor, z_random = torch.split(z, [batch_size, batch_size], dim=0)
    # Equivalent to cat((z_anchor.unsqueeze(1), z_random.unsqueeze(1)), dim=1).
    projections = torch.stack((z_anchor, z_random), dim=1)
    pwconloss = criterion(projections, landmark_select,
                          a_l_distance, a_l_relationship,
                          r_l_distance, r_l_relationship)
    loss_value = pwconloss.item()
    pwconloss.backward()
    return loss_value


def train(featnet_leye, featnet_reye, featnet_mouth,
          featnet_nose, featnet_jaw, train_loader, criterion, optimizer):
    """Run one training epoch over ``train_loader`` for all five region FeatNets.

    Each batch's landmarks are reshaped to ``(-1, 68, 2)`` and sliced into five
    landmark subsets, one per network. Every network then gets its own
    crop / forward / PWConLoss / backward pass, in the fixed order leye, reye,
    mouth, nose, jaw. A single ``optimizer.step()`` then updates all of them.

    Returns:
        Tuple ``(loss_leye, loss_reye, loss_mouth, loss_nose, loss_jaw)``.
        Each entry is that region's batch loss averaged over the
        ``len(train_loader)`` batches (``0`` each if the loader is empty).
    """
    featnets = (featnet_leye, featnet_reye, featnet_mouth, featnet_nose, featnet_jaw)
    for featnet in featnets:
        featnet.train()

    num_batches = len(train_loader)
    epoch_losses = [0] * len(featnets)

    for images, landmark_coords in train_loader:
        images, landmark_coords = images.to(device), landmark_coords.to(device)
        batch_size = images.size(0)
        landmark_coords = landmark_coords.view(-1, 68, 2)

        # Index ranges presumably follow the 68-point iBUG/300-W layout -- TODO confirm.
        # Order must match `featnets` above.
        region_coords = (
            # leye
            torch.cat((landmark_coords[:, 17:22], landmark_coords[:, 36:42]), dim=1),
            # reye
            torch.cat((landmark_coords[:, 22:27], landmark_coords[:, 42:48]), dim=1),
            # mouth
            landmark_coords[:, 48:68],
            # nose
            # NOTE(review): the "nose" network sees all 68 landmarks, while
            # indices 27:36 are grouped into the jaw region -- confirm intended.
            landmark_coords,
            # jaw
            torch.cat((landmark_coords[:, 0:17], landmark_coords[:, 27:36]), dim=1),
        )

        # Every FeatNet's parameters are registered with `optimizer`, so this
        # clears all five networks' gradients.
        optimizer.zero_grad()

        # Regions are processed sequentially (each backward frees its graph
        # before the next region's forward), preserving the original order of
        # anchor_random_crop calls.
        for idx, (featnet, coords) in enumerate(zip(featnets, region_coords)):
            batch_loss = _train_region(featnet, images, coords, criterion, batch_size)
            epoch_losses[idx] += batch_loss / num_batches

        optimizer.step()

    return tuple(epoch_losses)



# Script entry point: training only starts when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()
    
    
    
    
    
